# -*- coding: utf-8 -*-
"""pipeline.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1-m2ywJVcfgCHOcEN-4agAbLz7tRGqMvM
"""

# Prepare the model and the data.
# The model is held in the variable `model`.
# The data is later demonstrated with `example`.
# To use this script, simply substitute your own model.
#import numpy as np
import torch
import config
#from activate_neuron.mymodel import *
#import activate_neuron.mymodel as mymodel
#from activate_neuron.utils import *
#import activate_neuron.utils as utils


#from transformers import AutoConfig, AutoModelForMaskedLM
#from model.modelling_roberta import RobertaForMaskedLM
#from reader.reader import init_dataset, init_formatter, init_test_dataset

import argparse
import os
import torch
import logging
import random
import numpy as np

from tools.init_tool import init_all
from config_parser import create_config
from tools.valid_tool import valid
from torch.autograd import Variable

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

logger = logging.getLogger(__name__)

def set_random_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA
    devices) with the same value so that runs are reproducible.

    Args:
        seed: Seed value. ``None`` and non-positive values leave all
            generators untouched.
    """
    # Guard clause: nothing to do without a strictly positive seed.
    if seed is None or not (seed > 0):
        return

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)



def relu(tmp):
    """Return ``tmp`` with every non-positive entry replaced by zero.

    Works on anything that supports ``> 0`` and ``*`` (plain numbers, and
    element-wise on array-like objects that implement those operators).

    Args:
        tmp: Value or array-like to rectify.

    Returns:
        ``tmp`` multiplied by the 0/1 mask of its strictly positive entries.
    """
    # ``1 * mask`` turns the boolean mask into integers before the product.
    positive_mask = 1 * (tmp > 0)
    return positive_mask * tmp


def topk(obj, k):
    """Return the indices of the ``k`` largest elements of ``obj``.

    Indices are ordered from the largest value to the smallest; ties are
    resolved in favour of the lower index. The input is never modified.

    Unlike the previous implementation, no sentinel value is written over
    already-selected entries, so values at or below -10000 are ranked
    correctly, and asking for more items than exist returns each index at
    most once instead of repeating an index.

    Args:
        obj: Iterable of mutually comparable values.
        k: Number of indices wanted. Values <= 0 give an empty list.

    Returns:
        A list of at most ``min(k, len(obj))`` integer indices.
    """
    values = list(obj)
    if k <= 0:
        return []
    # sorted() is stable even with reverse=True, so equal values keep their
    # original (ascending index) order.
    ranked = sorted(range(len(values)), key=values.__getitem__, reverse=True)
    return ranked[:k]




if __name__ == "__main__":
    # Script entry point: parse the command line, build the model and
    # validation data through init_all(), register forward hooks on the
    # feed-forward `wi` projection of each decoder block, run one validation
    # pass so the hooks record the activations, then save a slice of the
    # recorded activations under task_activated_neuron/<save_name>/.

    def _parse_bool_flag(text):
        """Convert a command-line string into a bool.

        ``type=bool`` maps every non-empty string -- including "False" -- to
        True. Common "false" spellings are mapped to False here; any other
        value keeps the previous ``bool(text)`` result.
        """
        if text.strip().lower() in ("false", "f", "0", "no", "n", "off"):
            return False
        return bool(text)

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', help="specific config file", required=True)
    parser.add_argument('--gpu', '-g', help="gpu id list")
    parser.add_argument('--local_rank', type=int, help='local rank', default=-1)
    parser.add_argument('--do_test', help="do test while training or not", action="store_true")
    parser.add_argument('--checkpoint', help="checkpoint file path", type=str, default=None)
    # The help text used to be a copy of the --checkpoint one.
    parser.add_argument('--comment', help="optional comment", default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--prompt_emb_output", type=_parse_bool_flag, default=False)
    parser.add_argument("--save_name", type=str, default=None)
    parser.add_argument("--replacing_prompt", type=str, default=None)
    parser.add_argument("--pre_train_mlm", default=False, action='store_true')
    parser.add_argument("--task_transfer_projector", default=False, action='store_true')
    parser.add_argument("--model_transfer_projector", default=False, action='store_true')
    # NOTE: default=True combined with store_true means this is always True.
    parser.add_argument("--activate_neuron", default=True, action='store_true')
    parser.add_argument("--mode", type=str, default="valid")
    parser.add_argument("--projector", type=str, default=None)

    args = parser.parse_args()

    # Resolve the output name BEFORE the expensive validation pass.
    # Previously a missing --replacing_prompt only surfaced as an
    # AttributeError at the very end, after all the work had been done.
    # --replacing_prompt keeps priority; --save_name is only a fallback for
    # the case that used to crash.
    name_source = args.replacing_prompt if args.replacing_prompt is not None else args.save_name
    if name_source is None:
        parser.error("--replacing_prompt (or --save_name) is required to name the output directory")
    # Last path component, truncated at its first dot.
    save_name = name_source.strip().split("/")[-1].split(".")[0]

    configFilePath = args.config
    config = create_config(configFilePath)

    # Restrict the visible devices to the requested ones. gpu_list holds the
    # positions 0..n-1 within that list, not the raw ids given on the command
    # line (presumably because CUDA renumbers visible devices from 0).
    gpu_list = []
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        gpu_list = list(range(len(args.gpu.split(","))))

    os.system("clear")
    config.set('distributed', 'local_rank', args.local_rank)
    # Distributed mode is switched off right before it is tested, so the
    # branch below is presumably never taken -- confirm against the
    # project's config parser.
    config.set("distributed", "use", False)
    if config.getboolean("distributed", "use") and len(gpu_list) > 1:
        torch.cuda.set_device(gpu_list[args.local_rank])
        torch.distributed.init_process_group(backend=config.get("distributed", "backend"))
        config.set('distributed', 'gpu_num', len(gpu_list))

    cuda = torch.cuda.is_available()
    logger.info("CUDA available: %s", cuda)
    if not cuda and len(gpu_list) > 0:
        logger.error("CUDA is not available but specific gpu id")
        raise NotImplementedError
    set_random_seed(args.seed)

    # Disabled override that forces the "mlmPrompt" formatter and model:
    #   formatter = "mlmPrompt"
    #   config.set("data", "train_formatter_type", formatter)
    #   config.set("data", "valid_formatter_type", formatter)
    #   config.set("data", "test_formatter_type", formatter)
    #   config.set("model", "model_name", "mlmPrompt")

    parameters = init_all(config, gpu_list, args.checkpoint, args.mode, local_rank=args.local_rank, args=args)

    model = parameters["model"]
    valid_dataset = parameters["valid_dataset"]

    # ---- Prepare the hooks (feature-extraction code) ----
    # Number of blocks that get a hook; must not exceed the number of
    # decoder blocks. NOTE(review): hard-coded, presumably for a 12-layer
    # model -- confirm before using a different model size.
    num_layers = 12

    # outputs[n] collects one CPU tensor per forward call of layer n.
    outputs = [[] for _ in range(num_layers)]

    def save_ppt_outputs1_hook(n):
        """Build a forward hook that appends the module output to outputs[n]."""
        def fn(_, __, output):
            # Detach and move to CPU so recorded activations do not keep the
            # autograd graph or accelerator memory alive.
            # Original author's note: output.shape was
            # torch.Size([1, 1, 3072]) in their run.
            outputs[n].append(output.detach().to("cpu"))
        return fn

    for n in range(num_layers):
        # The module that features are extracted from can be changed. The
        # original author notes that the two-level attribute path is needed
        # because of their custom model wrapper.

        # decoder
        model.encoder.decoder.block[n].layer[2].DenseReluDense.wi.register_forward_hook(save_ppt_outputs1_hook(n))

        # encoder (alternative)
        # model.encoder.encoder.block[n].layer[1].DenseReluDense.wi.register_forward_hook(save_ppt_outputs1_hook(n))

    # ---- Pass the data through the model ----
    # The hooks automatically store the intermediate activations in `outputs`.
    model.eval()
    valid(model, valid_dataset, 1, None, config, gpu_list, parameters["output_function"], mode=args.mode, args=args)

    # Original author's debug notes on the recorded structure (their run;
    # TODO confirm for other setups):
    #   len(outputs)             -> 12    layers
    #   len(outputs[0])          -> 17    recorded forward calls ("epoch")
    #   len(outputs[0][0])       -> 64    batch size
    #   len(outputs[0][0][0])    -> 231   input length
    #   len(outputs[0][0][0][0]) -> 3072  neurons

    # Merge all recorded calls of each layer into one tensor. Done in place,
    # layer by layer, so each list of chunks is released as soon as it has
    # been concatenated.
    for k in range(num_layers):
        outputs[k] = torch.cat(outputs[k])

    # Disabled: find the maximum activation of one neuron, following the
    # code from the paper (`length` and `size` are not defined here):
    #   layer = torch.randint(1, 12, (1,))      # pick the layer
    #   neuron = torch.randint(1, 3072, (1,))   # pick the neuron
    #   # all activations of that neuron in that layer
    #   neuron_activation = outputs[layer][:, :, neuron]
    #   max_activation = [neuron_activation[i, :length[i]].max() for i in range(size)]

    outputs = torch.stack(outputs)

    # decoder: keep only the first entry along dim 1 and along dim 2.
    # Original author's note: [12, 1, 1, 3072] -> layers, batch_size,
    # target_length, neurons ("[mask]" position).
    # clone() gives the slice its own storage; saving a view with
    # torch.save would serialise the whole underlying activation buffer.
    outputs = outputs[:, :1, :1, :].clone()

    # encoder (alternative)
    # outputs = outputs[:, :, 100:101, :]

    # Create task_activated_neuron/<save_name>/ if needed, then save there.
    save_dir = os.path.join("task_activated_neuron", save_name)
    os.makedirs(save_dir, exist_ok=True)
    torch.save(outputs, os.path.join(save_dir, "task_activated_neuron"))

    print("==Prompt emb==")
    print(outputs.shape)
    print("Save Done")
    print("==============")

    # Disabled: activated neurons for a task-specific prompt.
    #   size = 8      # number of sentences
    #   length = 231  # sentence length
    #   for layer in range(1, 12):
    #       for neuron in range(1, 3072):
    #           neuron_activation = outputs[layer][:, :, neuron]
    #           max_activation = [neuron_activation[i, :length[i]].max() for i in range(size)]
    #
    # Disabled: select the top few sentences to display (`tokenizer` and
    # `example` are not defined here).
    #   N = 4
    #   indexes = topk(max_activation, N)
    #   for ids in indexes:
    #       print(tokenizer.decode(example['input_ids'][ids, :length[ids]]))